{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练模型\n",
    "\n",
    "## 引入第三方包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from PIL import Image\n",
    "from keras import backend as K\n",
    "from keras.utils.vis_utils import plot_model\n",
    "from keras.models import *\n",
    "from keras.layers import *\n",
    "\n",
    "import glob\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow.gfile as gfile\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义超参数和字符集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\n",
    "LOWERCASE = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u',\n",
    "            'v', 'w', 'x', 'y', 'z']\n",
    "UPPERCASE = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U',\n",
    "           'V', 'W', 'X', 'Y', 'Z']\n",
    "\n",
    "CAPTCHA_CHARSET = NUMBER   # 验证码字符集\n",
    "CAPTCHA_LEN = 4            # 验证码长度\n",
    "CAPTCHA_HEIGHT = 60        # 验证码高度\n",
    "CAPTCHA_WIDTH = 160        # 验证码宽度\n",
    "\n",
    "TRAIN_DATA_DIR = './train-data/' # 验证码数据集目录\n",
    "TEST_DATA_DIR = './test-data/'\n",
    "\n",
    "BATCH_SIZE = 100\n",
    "EPOCHS = 10\n",
    "OPT = 'adam'\n",
    "LOSS = 'binary_crossentropy'\n",
    "\n",
    "MODEL_DIR = './model/train_demo/'\n",
    "MODEL_FORMAT = '.h5'\n",
    "HISTORY_DIR = './history/train_demo/'\n",
    "HISTORY_FORMAT = '.history'\n",
    "\n",
    "filename_str = \"{}captcha_{}_{}_bs_{}_epochs_{}{}\"\n",
    "\n",
    "# 模型网络结构文件\n",
    "MODEL_VIS_FILE = 'captcha_classfication' + '.png'\n",
    "# 模型文件\n",
    "MODEL_FILE = filename_str.format(MODEL_DIR, OPT, LOSS, str(BATCH_SIZE), str(EPOCHS), MODEL_FORMAT)\n",
    "# 训练记录文件\n",
    "HISTORY_FILE = filename_str.format(HISTORY_DIR, OPT, LOSS, str(BATCH_SIZE), str(EPOCHS), HISTORY_FORMAT)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 将 RGB 验证码图像转为灰度图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rgb2gray(img):\n",
    "    # Y' = 0.299 R + 0.587 G + 0.114 B \n",
    "    # https://en.wikipedia.org/wiki/Grayscale#Converting_color_to_grayscale\n",
    "    return np.dot(img[...,:3], [0.299, 0.587, 0.114])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 对验证码中每个字符进行 one-hot 编码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def text2vec(text, length=CAPTCHA_LEN, charset=CAPTCHA_CHARSET):\n",
    "    text_len = len(text)\n",
    "    # 验证码长度校验\n",
    "    if text_len != length:\n",
    "        raise ValueError('Error: length of captcha should be {}, but got {}'.format(length, text_len))\n",
    "    \n",
    "    # 生成一个形如（CAPTCHA_LEN*CAPTHA_CHARSET,) 的一维向量\n",
    "    # 例如，4个纯数字的验证码生成形如(4*10,)的一维向量\n",
    "    vec = np.zeros(length * len(charset))\n",
    "    for i in range(length):\n",
    "        # One-hot 编码验证码中的每个数字\n",
    "        # 每个字符的热码 = 索引 + 偏移量\n",
    "        vec[charset.index(text[i]) + i*len(charset)] = 1\n",
    "    return vec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 将验证码向量解码为对应字符"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vec2text(vector):\n",
    "    if not isinstance(vector, np.ndarray):\n",
    "        vector = np.asarray(vector)\n",
    "    vector = np.reshape(vector, [CAPTCHA_LEN, -1])\n",
    "    text = ''\n",
    "    for item in vector:\n",
    "        text += CAPTCHA_CHARSET[np.argmax(item)]\n",
    "    return text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 适配 Keras 图像数据格式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_keras_channels(batch, rows=CAPTCHA_HEIGHT, cols=CAPTCHA_WIDTH):\n",
    "    if K.image_data_format() == 'channels_first':\n",
    "        batch = batch.reshape(batch.shape[0], 1, rows, cols)\n",
    "        input_shape = (1, rows, cols)\n",
    "    else:\n",
    "        batch = batch.reshape(batch.shape[0], rows, cols, 1)\n",
    "        input_shape = (rows, cols, 1)\n",
    "    \n",
    "    return batch, input_shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取训练集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = []\n",
    "Y_train = []\n",
    "for filename in glob.glob(TRAIN_DATA_DIR + '*.png'):\n",
    "    X_train.append(np.array(Image.open(filename)))\n",
    "    Y_train.append(filename.lstrip(TRAIN_DATA_DIR).rstrip('.png'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 处理训练集图像"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3956, 60, 160, 1) <class 'numpy.ndarray'>\n",
      "(60, 160, 1)\n"
     ]
    }
   ],
   "source": [
    "# list -> rgb(numpy)\n",
    "X_train = np.array(X_train, dtype=np.float32)\n",
    "# rgb -> gray\n",
    "X_train = rgb2gray(X_train)\n",
    "# normalize\n",
    "X_train = X_train / 255\n",
    "# Fit keras channels\n",
    "X_train, input_shape = fit_keras_channels(X_train)\n",
    "\n",
    "print(X_train.shape, type(X_train))\n",
    "print(input_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 处理训练集标签"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3956, 40) <class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "Y_train = list(Y_train)\n",
    "\n",
    "for i in range(len(Y_train)):\n",
    "    Y_train[i] = text2vec(Y_train[i])\n",
    "\n",
    "Y_train = np.asarray(Y_train)\n",
    "\n",
    "print(Y_train.shape, type(Y_train))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取测试集，处理对应图像和标签"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(954, 60, 160, 1) <class 'numpy.ndarray'>\n",
      "(954, 40) <class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "X_test = []\n",
    "Y_test = []\n",
    "for filename in glob.glob(TEST_DATA_DIR + '*.png'):\n",
    "    X_test.append(np.array(Image.open(filename)))\n",
    "    Y_test.append(filename.lstrip(TEST_DATA_DIR).rstrip('.png'))\n",
    "\n",
    "# list -> rgb -> gray -> normalization -> fit keras \n",
    "X_test = np.array(X_test, dtype=np.float32)\n",
    "X_test = rgb2gray(X_test)\n",
    "X_test = X_test / 255\n",
    "X_test, _ = fit_keras_channels(X_test)\n",
    "\n",
    "Y_test = list(Y_test)\n",
    "for i in range(len(Y_test)):\n",
    "    Y_test[i] = text2vec(Y_test[i])\n",
    "\n",
    "Y_test = np.asarray(Y_test)\n",
    "\n",
    "print(X_test.shape, type(X_test))\n",
    "print(Y_test.shape, type(Y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 创建验证码识别模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 输入层\n",
    "inputs = Input(shape = input_shape, name = \"inputs\")\n",
    "\n",
    "# 第1层卷积\n",
    "conv1 = Conv2D(32, (3, 3), name = \"conv1\")(inputs)\n",
    "relu1 = Activation('relu', name=\"relu1\")(conv1)\n",
    "\n",
    "# 第2层卷积\n",
    "conv2 = Conv2D(32, (3, 3), name = \"conv2\")(relu1)\n",
    "relu2 = Activation('relu', name=\"relu2\")(conv2)\n",
    "pool2 = MaxPooling2D(pool_size=(2,2), padding='same', name=\"pool2\")(relu2)\n",
    "\n",
    "# 第3层卷积\n",
    "conv3 = Conv2D(64, (3, 3), name = \"conv3\")(pool2)\n",
    "relu3 = Activation('relu', name=\"relu3\")(conv3)\n",
    "pool3 = MaxPooling2D(pool_size=(2,2), padding='same', name=\"pool3\")(relu3)\n",
    "\n",
    "# 将 Pooled feature map 摊平后输入全连接网络\n",
    "x = Flatten()(pool3)\n",
    "\n",
    "# Dropout\n",
    "x = Dropout(0.25)(x)\n",
    "\n",
    "# 4个全连接层分别做10分类，分别对应4个字符。\n",
    "x = [Dense(10, activation='softmax', name='fc%d'%(i+1))(x) for i in range(4)]\n",
    "\n",
    "# 4个字符向量拼接在一起，与标签向量形式一致，作为模型输出。\n",
    "outs = Concatenate()(x)\n",
    "\n",
    "# 定义模型的输入与输出\n",
    "model = Model(inputs=inputs, outputs=outs)\n",
    "model.compile(optimizer=OPT, loss=LOSS, metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 查看模型摘要"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "inputs (InputLayer)             (None, 60, 160, 1)   0                                            \n",
      "__________________________________________________________________________________________________\n",
      "conv1 (Conv2D)                  (None, 58, 158, 32)  320         inputs[0][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "relu1 (Activation)              (None, 58, 158, 32)  0           conv1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "conv2 (Conv2D)                  (None, 56, 156, 32)  9248        relu1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "relu2 (Activation)              (None, 56, 156, 32)  0           conv2[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "pool2 (MaxPooling2D)            (None, 28, 78, 32)   0           relu2[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "conv3 (Conv2D)                  (None, 26, 76, 64)   18496       pool2[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "relu3 (Activation)              (None, 26, 76, 64)   0           conv3[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "pool3 (MaxPooling2D)            (None, 13, 38, 64)   0           relu3[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "flatten_1 (Flatten)             (None, 31616)        0           pool3[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 31616)        0           flatten_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "fc1 (Dense)                     (None, 10)           316170      dropout_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "fc2 (Dense)                     (None, 10)           316170      dropout_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "fc3 (Dense)                     (None, 10)           316170      dropout_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "fc4 (Dense)                     (None, 10)           316170      dropout_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_1 (Concatenate)     (None, 40)           0           fc1[0][0]                        \n",
      "                                                                 fc2[0][0]                        \n",
      "                                                                 fc3[0][0]                        \n",
      "                                                                 fc4[0][0]                        \n",
      "==================================================================================================\n",
      "Total params: 1,292,744\n",
      "Trainable params: 1,292,744\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_model(model, to_file=MODEL_VIS_FILE, show_shapes=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 3956 samples, validate on 954 samples\n",
      "Epoch 1/10\n",
      " - 149s - loss: 0.3270 - acc: 0.9000 - val_loss: 0.3247 - val_acc: 0.9000\n",
      "Epoch 2/10\n",
      " - 122s - loss: 0.3229 - acc: 0.9000 - val_loss: 0.3195 - val_acc: 0.9000\n",
      "Epoch 3/10\n",
      " - 114s - loss: 0.2987 - acc: 0.9004 - val_loss: 0.2726 - val_acc: 0.9028\n",
      "Epoch 4/10\n",
      " - 106s - loss: 0.2257 - acc: 0.9164 - val_loss: 0.2303 - val_acc: 0.9171\n",
      "Epoch 5/10\n",
      " - 103s - loss: 0.1799 - acc: 0.9337 - val_loss: 0.2171 - val_acc: 0.9209\n",
      "Epoch 6/10\n",
      " - 113s - loss: 0.1523 - acc: 0.9447 - val_loss: 0.2062 - val_acc: 0.9254\n",
      "Epoch 7/10\n",
      " - 112s - loss: 0.1383 - acc: 0.9498 - val_loss: 0.2048 - val_acc: 0.9260\n",
      "Epoch 8/10\n",
      " - 127s - loss: 0.1251 - acc: 0.9550 - val_loss: 0.2052 - val_acc: 0.9260\n",
      "Epoch 9/10\n",
      " - 161s - loss: 0.1144 - acc: 0.9587 - val_loss: 0.2013 - val_acc: 0.9285\n",
      "Epoch 10/10\n",
      " - 159s - loss: 0.1063 - acc: 0.9618 - val_loss: 0.2045 - val_acc: 0.9268\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(X_train,\n",
    "                    Y_train,\n",
    "                    batch_size=BATCH_SIZE,\n",
    "                    epochs=EPOCHS,\n",
    "                    verbose=2,\n",
    "                    validation_data=(X_test, Y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测样例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3578\n"
     ]
    }
   ],
   "source": [
    "print(vec2text(Y_test[9]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "yy = model.predict(X_test[9].reshape(1, 60, 160, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3584\n"
     ]
    }
   ],
   "source": [
    "print(vec2text(yy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 保存模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved trained model at ./model/train_demo/captcha_adam_binary_crossentropy_bs_100_epochs_10.h5 \n"
     ]
    }
   ],
   "source": [
    "if not gfile.Exists(MODEL_DIR):\n",
    "    gfile.MakeDirs(MODEL_DIR)\n",
    "\n",
    "model.save(MODEL_FILE)\n",
    "print('Saved trained model at %s ' % MODEL_FILE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 保存训练过程记录"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.8999999182409898,\n",
       " 0.8999999182409898,\n",
       " 0.9004359691883121,\n",
       " 0.9163548910220062,\n",
       " 0.9337462147639423,\n",
       " 0.9447105740874071,\n",
       " 0.9498167388118555,\n",
       " 0.9550177077673323,\n",
       " 0.9587146095004916,\n",
       " 0.9618301110754601]"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "history.history['acc']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['val_loss', 'val_acc', 'loss', 'acc'])"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "history.history.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "if gfile.Exists(HISTORY_DIR) == False:\n",
    "    gfile.MakeDirs(HISTORY_DIR)\n",
    "\n",
    "with open(HISTORY_FILE, 'wb') as f:\n",
    "    pickle.dump(history.history, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./history/train_demo/captcha_adam_binary_crossentropy_bs_100_epochs_10.history\n"
     ]
    }
   ],
   "source": [
    "print(HISTORY_FILE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
